import pandas as pd
import numpy as np
from copy import deepcopy, copy
import cvxpy as cvx

class linear_bandits(object):

    def __init__(self, dict_params, dt_env, prefix_sep="_zl_"):
        """Initialise the simulation state.

        Args:
            dict_params: Configuration dictionary. ``"seed"`` is read here;
                the helpers called from here read further keys (``"T"``,
                ``"var_rate"``, ``"var_context"``, ``"var_base_reward_costs"``,
                ``"var_model_onehot"``, ``"lmd"``, ``"UCB_multiply"``,
                ``"norm_costs1"``, ``"eta_oco"`` and optionally ``"Z"``).
            dt_env: DataFrame of contexts handed to ``_init_sequence_simu``.
            prefix_sep: Separator used by ``pd.get_dummies`` when naming
                one-hot columns.

        Raises:
            KeyError: If a required key is missing from ``dict_params``.
        """
        self.dict_params = dict_params
        self.prefix_sep = prefix_sep
        self.current_iter = 0

        # Per-iteration history, one entry appended per simulated round.
        self.actions = []
        self.conversions = []
        self.rewards = []
        self.costs2 = []
        self.costs1 = []

        # Seed BEFORE building the sequence: ``_init_sequence_simu`` may
        # resample the data through the global NumPy RNG, and that draw must be
        # covered by the seed for runs to be reproducible.
        np.random.seed(self.dict_params["seed"])

        self._init_sequence_simu(dt_env)
        self._init_constant()

    def _init_constant(self):
        """Copy scalar settings out of ``dict_params`` and reset both thetas.

        Sets ``lmd``, ``UCB_multiply``, ``norm_costs1`` and ``eta_oco`` from the
        keys of the same name, zeroes ``current_theta1`` / ``current_theta2``,
        and records in ``given_Z`` whether a ``"Z"`` entry was supplied
        (``self.Z`` is only assigned when it was).
        """
        params = self.dict_params

        for attr_name in ("lmd", "UCB_multiply", "norm_costs1", "eta_oco"):
            setattr(self, attr_name, params[attr_name])

        self.current_theta2 = 0
        self.current_theta1 = 0

        self.given_Z = "Z" in params
        if self.given_Z:
            self.Z = params["Z"]

    def _init_sequence_simu(self, dt):
        """Build the context sequence and its one-hot encoded counterpart.

        When ``dt`` does not have exactly ``dict_params["T"]`` rows it is
        resampled with replacement to that length; otherwise it is used as is.
        The result is stored in ``self.dt``. ``self.dt_dummy`` holds the one-hot
        encoding of the SAME rows, restricted to (and ordered as)
        ``var_model_onehot + var_base_reward_costs``; model columns that the
        encoding did not produce are added and filled with 0.

        Args:
            dt: DataFrame containing at least the columns named by
                ``var_rate``, ``var_context`` and ``var_base_reward_costs``.
        """
        params = self.dict_params
        n_rounds = params["T"]

        if dt.shape[0] != n_rounds:
            print(f"Sampling data to reach {n_rounds}")
            self.dt = dt.sample(n_rounds, replace=True).reset_index(drop=True)
        else:
            self.dt = dt

        source_columns = (
            [params["var_rate"]] + params["var_context"] + params["var_base_reward_costs"]
        )

        # Encode ``self.dt`` (not the raw input): after resampling, row i of the
        # encoded frame must describe row i of ``self.dt``. ``dtype=float`` keeps
        # ``.values`` numeric on pandas versions whose dummy columns default to
        # bool (bool mixed with float would otherwise yield an object array).
        dt_dummy = pd.get_dummies(
            self.dt[source_columns], prefix_sep=self.prefix_sep, dtype=float
        )

        # Model columns absent from this data (e.g. an unseen category, or
        # "discount") are created as all-zero columns.
        for missing_col in [c for c in params["var_model_onehot"] if c not in dt_dummy.columns]:
            dt_dummy[missing_col] = 0

        self.dt_dummy = dt_dummy[params["var_model_onehot"] + params["var_base_reward_costs"]]

    def _get_context(self):
        """Snapshot the context for the current iteration and the history so far.

        Stores independent copies (index renumbered from 0) of:
        ``current_context`` / ``current_context_dummy`` - the single row at
        ``current_iter`` of ``dt`` / ``dt_dummy``;
        ``backlog_context`` / ``backlog_context_dummy`` - every row up to and
        including ``current_iter``.
        """
        first = self.current_iter
        stop = first + 1

        def _detached(frame, start):
            # Slice, renumber, then deep-copy so edits never reach the source frame.
            return deepcopy(frame.iloc[start:stop].reset_index(drop=True))

        self.current_context = _detached(self.dt, first)
        self.current_context_dummy = _detached(self.dt_dummy, first)

        self.backlog_context = _detached(self.dt, None)
        self.backlog_context_dummy = _detached(self.dt_dummy, None)

    def _cal_constant_UCB(self):
        """Refresh the confidence-bonus scale for the current iteration.

        Stores ``(log(current_iter + 1) + 1) * UCB_multiply`` in
        ``current_constant_UCB``.
        """
        rounds_seen = self.current_iter + 1
        log_growth = np.log(rounds_seen) + 1
        self.current_constant_UCB = log_growth * self.UCB_multiply

    def _update_weights(self):
        """Fold the latest observation into the ridge-regression estimates.

        With ``x`` the current one-row feature array, accumulates
        ``vt_raw += outer(x, x)`` and ``x_r / x_c2 / x_c1 += x * target`` for
        the current reward, cost2 and cost1. It then sets
        ``vt = vt_raw + lmd * I`` and stores ``inv(vt)`` applied to each
        accumulated vector as ``reward_weights``, ``cost2_weights`` and
        ``cost1_weights`` (each of shape ``(1, d)``).

        The accumulators are (re)started at iteration 0, and also whenever they
        do not exist yet, so a first update at a later iteration does not fail.
        """
        x = self.current_context_dummy.values
        gram_update = np.outer(x, x)

        start_fresh = self.current_iter == 0 or not hasattr(self, "vt_raw")
        if start_fresh:
            self.vt_raw = gram_update
            self.x_r = x * self.current_reward
            self.x_c2 = x * self.current_cost2
            self.x_c1 = x * self.current_cost1
        else:
            # Out-of-place sums: an in-place ``+=`` would fail if an integer
            # accumulator met a float update.
            self.vt_raw = self.vt_raw + gram_update
            self.x_r = self.x_r + x * self.current_reward
            self.x_c2 = self.x_c2 + x * self.current_cost2
            self.x_c1 = self.x_c1 + x * self.current_cost1

        n_features = x.shape[1]
        self.vt = self.vt_raw + np.identity(n_features) * self.lmd

        # Invert once and reuse for the three targets.
        vt_inv = np.linalg.inv(self.vt)
        self.reward_weights = np.inner(vt_inv, self.x_r).T
        self.cost2_weights = np.inner(vt_inv, self.x_c2).T
        self.cost1_weights = np.inner(vt_inv, self.x_c1).T

    def _cal_bonus_UCB(self, action, context, matrix_context):
        """Return the confidence bonus for each row of ``context``.

        For every row ``x`` the bonus is
        ``current_constant_UCB * sqrt(x @ inv(vt) @ x)``.

        Args:
            action: Action label; only used in the verbose log line.
            context: 2-D array of feature rows.
            matrix_context: The same rows. Either an ``np.matrix`` or a plain
                2-D array is accepted; the product below is an explicit matrix
                multiplication, so it does not rely on ``np.matrix``
                overloading ``*``.

        Returns:
            1-D ``ndarray`` with one bonus per row.
        """
        rows = np.asarray(context)
        projected = np.asarray(matrix_context) @ np.linalg.inv(self.vt)

        # Row-wise quadratic form x @ inv(vt) @ x.
        quadratic_form = np.sum(projected * rows, axis=1)
        current_bonus_UCB = self.current_constant_UCB * np.sqrt(quadratic_form)

        if self.dict_params["verbose"] is True:
            if self.current_iter % 50 == 0:
                print(f"For action {action}: Average bonus {np.round(np.mean(current_bonus_UCB), 4)}")

        return current_bonus_UCB

    def _take_action(self, type_action):
        """Choose an action, record it, and return it.

        Args:
            type_action: ``"anull"`` selects the null action ``-1``;
                ``"random"`` draws uniformly from the non-null entries of
                ``dict_params["list_actions"]``; ``"exploitation"`` delegates
                to ``_action_max_weighted_reward``.

        Returns:
            The chosen action. It is also appended to ``self.actions`` and
            stored in ``self.current_action``.

        Raises:
            ValueError: If ``type_action`` is not one of the three modes.
        """
        if type_action == "anull":
            action = -1
        elif type_action == "random":
            # Exclude the null action by value, as the rest of the class does,
            # instead of assuming it sits first in the list.
            candidate_actions = [a for a in self.dict_params["list_actions"] if a != -1]
            action = np.random.choice(candidate_actions)
        elif type_action == "exploitation":
            action = self._action_max_weighted_reward()
        else:
            raise ValueError("type_action is wrong")

        self.actions.append(action)
        self.current_action = action
        return action

    def _update_action_context(self):
        """Apply the chosen discount to the current one-hot context.

        Nothing happens for the null action ``-1``. Otherwise the rate column
        is scaled by ``1 - action / 100`` and the ``"discount"`` column is set
        to ``action / 100``.
        """
        if self.current_action == -1:
            return

        rate_col = self.dict_params["var_rate"]
        features = self.current_context_dummy

        features[rate_col] = features[rate_col] * (1 - self.current_action / 100)
        features["discount"] = self.current_action / 100

    def _get_reward_costs(self):
        """Simulate the outcome of the current action and record it.

        For the null action ``-1`` a zero is appended to each of
        ``conversions``, ``rewards``, ``costs2`` and ``costs1`` and nothing
        else changes. Otherwise a conversion is drawn as a Bernoulli trial with
        the probability given by ``dict_params["model_conversion"]``, and:

        * reward = conversion * ``amount_norm``
        * cost2  = conversion * ``discount_base_norm`` * discount
        * cost1  = conversion * discount / ``norm_costs1``

        The values are appended to the histories and stored as
        ``current_reward`` / ``current_cost2`` / ``current_cost1``.
        """
        if self.current_action == -1:
            for history in (self.conversions, self.rewards, self.costs2, self.costs1):
                history.append(0)
            return

        discount_fraction = self.current_action / 100
        outcome = deepcopy(self.current_context)

        model_inputs = self.current_context_dummy[self.dict_params["var_model_onehot"]]
        outcome["pred_conv"] = self.dict_params["model_conversion"].predict_proba(model_inputs)[:, 1]

        # One uniform draw per row decides whether that row converts.
        converted = outcome.apply(lambda row: np.random.rand() < row["pred_conv"], axis=1)
        outcome["conversion"] = converted * 1

        outcome["reward"] = outcome["conversion"] * outcome["amount_norm"]
        outcome["cost2"] = outcome["conversion"] * outcome["discount_base_norm"] * discount_fraction
        outcome["cost1"] = (outcome["conversion"] * discount_fraction) / self.norm_costs1

        if self.dict_params["verbose"] is True and self.current_iter % 50 == 0:
            print(f"current conversion: {np.round(outcome['conversion'][0], 4)}")
            print(f"current reward: {np.round(outcome['reward'][0], 4)}")
            print(f"current cost2: {np.round(outcome['cost2'][0], 4)}")
            print(f"current cost1: {np.round(outcome['cost1'][0], 4)}")

        observed = {
            name: outcome[name][0] for name in ("conversion", "reward", "cost2", "cost1")
        }

        self.conversions.append(observed["conversion"])
        self.rewards.append(observed["reward"])
        self.costs2.append(observed["cost2"])
        self.costs1.append(observed["cost1"])

        self.current_reward = observed["reward"]
        self.current_cost2 = observed["cost2"]
        self.current_cost1 = observed["cost1"]

    def _check_break_constraints(self):
        """Tell whether either cumulative cost has passed ``budget - 1``.

        Returns:
            ``True`` when the sum of ``costs2`` or the sum of ``costs1`` is
            strictly greater than ``dict_params["budget"] - 1``, else ``False``
            (always a plain ``bool``).
        """
        threshold = self.dict_params["budget"] - 1
        spent_so_far = (np.sum(self.costs2), np.sum(self.costs1))
        return any(total > threshold for total in spent_so_far)

    def _calculate_exploration_gain_costs_matrix(self):
        """Predict reward and both costs for every action and store them.

        When ``current_iter == dict_params["n_random_action"]`` the whole
        backlog of contexts is scored and ``backlog_matrix_exploration`` is
        started; on any other iteration only the current context is scored and
        appended to the existing ``backlog_matrix_exploration``.

        For each action ``a`` three columns are added to the (non-encoded)
        context rows: ``volume_a`` (features . ``reward_weights``),
        ``discount_amount_a`` (features . ``cost2_weights``) and
        ``discount_sum_a`` (features . ``cost1_weights``). The null action
        ``-1`` gets zeros. For other actions the features are the model columns
        with the rate scaled by ``1 - a / 100`` and ``"discount"`` set to
        ``a / 100``; ``"discount"`` is expected to be one of the model columns.

        When ``current_iter == dict_params["T0"] - 1`` the three families of
        columns are additionally exported as arrays:
        ``target_matrix_exploration`` (``volume_*``),
        ``constraint1_matrix_exploration`` (``discount_amount_*``) and
        ``constraint2_matrix_exploration`` (``discount_sum_*``).
        """
        params = self.dict_params
        rate_col = params["var_rate"]
        feature_cols = params["var_model_onehot"] + params["var_base_reward_costs"]
        actions = params["list_actions"]
        starts_backlog = self.current_iter == params["n_random_action"]

        if starts_backlog:
            features = self.backlog_context_dummy[feature_cols].copy()
            context_exp_ = deepcopy(self.backlog_context)
        else:
            features = self.current_context_dummy[feature_cols].copy()
            context_exp_ = deepcopy(self.current_context)

        # Keep the undiscounted rate in a local variable. Storing it as an extra
        # column of ``features`` would widen ``features.values`` beyond the
        # length of the weight vectors and break the inner products below.
        base_rate = features[rate_col].copy()

        for action in actions:
            if action == -1:
                context_exp_[f"volume_{action}"] = 0
                context_exp_[f"discount_amount_{action}"] = 0
                context_exp_[f"discount_sum_{action}"] = 0
                continue

            discount_ = action / 100
            features[rate_col] = base_rate * (1 - discount_)
            features["discount"] = discount_
            x = features.values

            # inner((n, d), (1, d)) has shape (n, 1); flatten to one value per row.
            context_exp_[f"volume_{action}"] = np.inner(x, self.reward_weights).ravel()
            context_exp_[f"discount_amount_{action}"] = np.inner(x, self.cost2_weights).ravel()
            context_exp_[f"discount_sum_{action}"] = np.inner(x, self.cost1_weights).ravel()

        if starts_backlog:
            self.backlog_matrix_exploration = context_exp_
        else:
            self.backlog_matrix_exploration = pd.concat(
                [self.backlog_matrix_exploration, context_exp_], axis=0
            ).reset_index(drop=True)

        if self.current_iter == params["T0"] - 1:
            backlog = self.backlog_matrix_exploration
            self.target_matrix_exploration = np.array(
                backlog[[f"volume_{a}" for a in actions]])
            self.constraint1_matrix_exploration = np.array(
                backlog[[f"discount_amount_{a}" for a in actions]])
            self.constraint2_matrix_exploration = np.array(
                backlog[[f"discount_sum_{a}" for a in actions]])

    def _OCO_calculate_theta(self):
        """Update ``current_theta2`` / ``current_theta1`` with one projected step.

        Each theta is first moved by ``eta_oco * (current cost - budget / T)``.
        The resulting pair is then projected, in the least-squares sense, onto
        the set ``theta2 >= 0, theta1 >= 0, |theta2| + |theta1| <= 1`` by
        solving a small cvxpy problem. The solution values are stored back on
        the instance.
        """
        B_dev_T = self.dict_params["budget"] / self.dict_params["T"]
        theta2_raw = self.current_theta2 + self.eta_oco * (self.current_cost2 - B_dev_T)
        theta1_raw = self.current_theta1 + self.eta_oco * (self.current_cost1 - B_dev_T)

        current_theta1 = cvx.Variable()
        current_theta2 = cvx.Variable()

        # Squared distance to the unconstrained update.
        obj_func = (current_theta2 - theta2_raw) ** 2 + (current_theta1 - theta1_raw) ** 2

        constraint_sum_1 = (cvx.abs(current_theta2) + cvx.abs(current_theta1)) <= 1
        constraint_non_negative_2 = current_theta2 >= 0
        constraint_non_negative_1 = current_theta1 >= 0

        constraints = [constraint_sum_1, constraint_non_negative_2, constraint_non_negative_1]

        obj = cvx.Minimize(obj_func)
        prob = cvx.Problem(obj, constraints)

        try:
            prob.solve(verbose=False)
        except Exception:
            # The default solver failed: retry with SCS. ``Exception`` rather
            # than a bare ``except`` so KeyboardInterrupt / SystemExit still
            # propagate.
            prob.solve(verbose=False, solver="SCS")

        self.current_theta2 = current_theta2.value
        self.current_theta1 = current_theta1.value

        if self.dict_params["verbose"] is True:
            if self.current_iter % 50 == 0:
                print(f"The theta2 is {self.current_theta2}")
                print(f"The theta1 is {self.current_theta1}")

    def _action_max_weighted_reward(self):
        """Return the non-null action with the highest penalised optimistic score.

        For each action ``a != -1`` the current context is temporarily set to
        that discount (rate scaled by ``1 - a / 100``, ``"discount"`` set to
        ``a / 100``) and scored as

            (reward estimate + bonus)
            - Z * ((cost2 estimate - bonus) * theta2 + (cost1 estimate - bonus) * theta1)

        where the bonus comes from ``_cal_bonus_UCB``. A strictly higher score
        wins; an exact tie is handed to the later action when a fresh uniform
        draw exceeds 0.5 (one draw is consumed per action regardless).
        ``current_context_dummy`` is restored to its original rate and discount
        before returning. ``-2`` is returned if no action was ever selected.
        """
        params = self.dict_params
        rate_col = params["var_rate"]
        rate_backup = f"{rate_col}_SAVE"
        discount_backup = "discount_SAVE"
        feature_cols = params["var_model_onehot"] + params["var_base_reward_costs"]

        frame = self.current_context_dummy
        # Park the original values in helper columns for the duration of the scan.
        frame[rate_backup] = frame[rate_col]
        frame[discount_backup] = frame["discount"]

        best_action = -2
        best_score = -999999

        for candidate in (a for a in params["list_actions"] if a != -1):
            fraction = candidate / 100
            frame[rate_col] = frame[rate_backup] * (1 - fraction)
            frame["discount"] = fraction

            # Helper columns are excluded: only model features are scored.
            x = frame[feature_cols].values
            bonus = self._cal_bonus_UCB(action=candidate, context=x, matrix_context=np.matrix(x))

            reward_ucb = np.squeeze(np.inner(x, self.reward_weights)) + bonus
            cost2_lcb = np.squeeze(np.inner(x, self.cost2_weights)) - bonus
            cost1_lcb = np.squeeze(np.inner(x, self.cost1_weights)) - bonus

            score = reward_ucb - self.Z * (cost2_lcb * self.current_theta2 + cost1_lcb * self.current_theta1)

            # Drawn unconditionally so the random stream advances once per action.
            coin_flip = np.random.rand() > 0.5
            if (score > best_score) | ((score == best_score) & coin_flip):
                best_action = candidate
                best_score = score

        frame[rate_col] = frame[rate_backup]
        frame["discount"] = frame[discount_backup]
        self.current_context_dummy = frame.drop(columns=[rate_backup, discount_backup])

        return best_action

    def _run_one_iteration(self):
        """Simulate a single round and advance ``current_iter`` by one.

        Steps: load the context; pick the action mode (null once the budget
        check trips, random while ``current_iter <= n_random_action - 1``,
        exploitation afterwards); apply the action to the context and draw its
        outcome. Unless the budget check tripped, the weight estimates are then
        updated and, from iteration ``n_random_action - 1`` onwards, the UCB
        constant and the thetas are refreshed as well.
        """
        self._get_context()

        # Neither the flag nor the iteration counter changes before the second
        # log line, so one evaluation serves both.
        log_this_round = self.dict_params["verbose"] is True and self.current_iter % 50 == 0
        if log_this_round:
            print(f"Iteration {self.current_iter}")

        break_constraint = self._check_break_constraints()
        last_random_iter = self.dict_params["n_random_action"] - 1

        if break_constraint == True:
            mode = "anull"
        elif self.current_iter <= last_random_iter:
            mode = "random"
        else:
            mode = "exploitation"
        self._take_action(type_action=mode)

        if log_this_round:
            print(f"current action: {self.current_action}")

        self._update_action_context()
        self._get_reward_costs()

        if break_constraint is False:
            self._update_weights()

            if self.current_iter >= last_random_iter:
                self._cal_constant_UCB()
                self._OCO_calculate_theta()

        self.current_iter += 1

    def run_simulation(self):
        """Run ``dict_params["T"]`` iterations back to back.

        In verbose mode the cumulative reward and both cumulative costs are
        printed whenever the iteration counter, after a round, is a multiple
        of 50.
        """
        for _ in range(self.dict_params["T"]):
            self._run_one_iteration()

            report_due = self.dict_params["verbose"] is True and self.current_iter % 50 == 0
            if not report_due:
                continue

            print(f"Cumulative Rewards: {np.round(np.sum(self.rewards), 4)}")
            print(f"Cumulative Costs2: {np.round(np.sum(self.costs2), 4)}")
            print(f"Cumulative Costs1: {np.round(np.sum(self.costs1), 4)}")
            print("------------------------------------")